/* vim: set ft=cpp:
 *
 * Copyright (c) 2018- Apple Computer, Inc. All rights reserved.
 *
 * @APPLE_LICENSE_HEADER_START@
 *
 * The contents of this file constitute Original Code as defined in and
 * are subject to the Apple Public Source License Version 2.0 (the
 * "License").  You may not use this file except in compliance with the
 * License.  Please obtain a copy of the License at
 * http://www.apple.com/publicsource and read it before using this file.
 *
 * This Original Code and all software distributed under the License are
 * distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, EITHER
 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT.  Please see the
 * License for the specific language governing rights and limitations
 * under the License.
 *
 * @APPLE_LICENSE_HEADER_END@
 */

// Since IOKit doesn't support the STL, for good reason, we would like to
// cherry-pick the more useful templates from C++11 and later.

#ifndef IOG_TL_IOLOCKS
#define IOG_TL_IOLOCKS

#include <_config>

#include <osutility>
#include <IOKit/IOLocks.h>

_IOG_START_NAMESPACE

// Declare a special type of IOSimpleLock that indicates interrupt blocking.
// The struct is an empty tag and is never instantiated in this file:
// IOSimpleLockInterruptAlloc() hands out an IOSimpleLock allocation
// reinterpret_cast to this type, so that the locking_primitives overloads
// taking an IOSimpleLockInterrupt* are selected and use the interrupt
// disabling lock/unlock calls.
struct IOSimpleLockInterrupt {};
// Views an IOSimpleLockInterrupt pointer as the IOSimpleLock pointer it was
// created from. This is a plain pointer reinterpretation; nothing is checked.
inline IOSimpleLock* IOSimpleLockCastOperator(IOSimpleLockInterrupt *sli)
{
    IOSimpleLock* const asSimpleLock = reinterpret_cast<IOSimpleLock*>(sli);
    return asSimpleLock;
}
// Allocates a lock with IOSimpleLockAlloc() and returns it typed as an
// IOSimpleLockInterrupt pointer. Whatever IOSimpleLockAlloc() returns is
// passed through unchanged apart from the pointer type.
inline IOSimpleLockInterrupt* IOSimpleLockInterruptAlloc()
{
    IOSimpleLock* const rawLock = IOSimpleLockAlloc();
    return reinterpret_cast<IOSimpleLockInterrupt*>(rawLock);
}
// Releases a lock obtained from IOSimpleLockInterruptAlloc() by converting it
// back to an IOSimpleLock pointer and handing it to IOSimpleLockFree().
inline void IOSimpleLockInterruptFree(IOSimpleLockInterrupt* sli)
{
    IOSimpleLockFree(IOSimpleLockCastOperator(sli));
}

namespace locking_primitives {

// Take locks
//
// One overload per supported lock type. Every overload returns an
// IOInterruptState so that generic code can store the result and pass it back
// to unlock(); only the IOSimpleLockInterrupt overload returns the value
// produced by the underlying call, the others always return 0.
inline IOInterruptState lock(IOLock* l)
{
    IOLockLock(l);
    return 0;
}

inline IOInterruptState lock(IORecursiveLock* l)
{
    IORecursiveLockLock(l);
    return 0;
}

inline IOInterruptState lock(IOSimpleLock* l)
{
    IOSimpleLockLock(l);
    return 0;
}

inline IOInterruptState lock(IOSimpleLockInterrupt* sli)
{
    // Hand back the state reported by the call so that unlock() can be given
    // the same value later.
    return IOSimpleLockLockDisableInterrupt(
        reinterpret_cast<IOSimpleLock*>(sli));
}

// Release locks
//
// Mirror images of the lock() overloads. The IOInterruptState argument is
// ignored by every overload except the IOSimpleLockInterrupt one, which
// forwards it to IOSimpleLockUnlockEnableInterrupt().
inline void unlock(IOLock* l, IOInterruptState)
{
    IOLockUnlock(l);
}

inline void unlock(IORecursiveLock* l, IOInterruptState)
{
    IORecursiveLockUnlock(l);
}

inline void unlock(IOSimpleLock* l, IOInterruptState)
{
    IOSimpleLockUnlock(l);
}

inline void unlock(IOSimpleLockInterrupt* sli, IOInterruptState is)
{
    IOSimpleLockUnlockEnableInterrupt(
        reinterpret_cast<IOSimpleLock*>(sli), is);
}

// Lock assertions
// TODO(gvdl): Hack alert, for some reason IOLocks.h doesn't publish the
// assertions out of XNU_KERNEL_PRIVATE. Until it is published I have copied
// and pasted the implementation, radar to follow.
//
// The casts are deliberately written inside the macro arguments rather than
// stored in locals, so nothing is left behind if the assertion macros expand
// to nothing.
inline void assertLocked(IOLock* l)
{
    LCK_MTX_ASSERT(static_cast<lck_mtx_t*>(l), LCK_ASSERT_OWNED);
}

inline void assertLocked(IORecursiveLock* l)
{
    assert(IORecursiveLockHaveLock(l));
}

inline void assertLocked(IOSimpleLock* l)
{
    LCK_SPIN_ASSERT(static_cast<lck_spin_t*>(l), LCK_ASSERT_OWNED);
}

inline void assertLocked(IOSimpleLockInterrupt* l)
{
    LCK_SPIN_ASSERT(reinterpret_cast<lck_spin_t*>(l), LCK_ASSERT_OWNED);
}

// Counterparts of assertLocked() that check the lock is NOT owned.
inline void assertUnlocked(IOLock* l)
{
    LCK_MTX_ASSERT(static_cast<lck_mtx_t*>(l), LCK_ASSERT_NOTOWNED);
}

inline void assertUnlocked(IORecursiveLock* l)
{
    assert(!IORecursiveLockHaveLock(l));
}

inline void assertUnlocked(IOSimpleLock* l)
{
    LCK_SPIN_ASSERT(static_cast<lck_spin_t*>(l), LCK_ASSERT_NOTOWNED);
}

inline void assertUnlocked(IOSimpleLockInterrupt* l)
{
    LCK_SPIN_ASSERT(reinterpret_cast<lck_spin_t*>(l), LCK_ASSERT_NOTOWNED);
}

} // namespace locking_primitives

// LockGuard works as is with IOLock, IORecursiveLock, IOSimpleLock and the
// above IOSimpleLockInterrupt (which disables interrupts too).
template <class _L>
class LockGuard {
public:
    LockGuard() = delete;
    LockGuard(const LockGuard&) = delete;
    LockGuard& operator=(const LockGuard&) = delete;
    LockGuard& operator=(LockGuard&& other) = delete;

    explicit LockGuard(_L* lock) : fLock(lock)
    {
        fMostlyUnusedInterruptState = locking_primitives::lock(fLock);
    }
    ~LockGuard()
    {
        if (static_cast<bool>(fLock))
            locking_primitives::unlock(fLock, fMostlyUnusedInterruptState);
    }
    LockGuard(LockGuard&& other) : fLock(other.fLock) { other.fLock = 0; }
private:
    _L* fLock = nullptr;
    IOInterruptState fMostlyUnusedInterruptState = 0;
};

_IOG_END_NAMESPACE

#endif // !IOG_TL_IOLOCKS
